#include <os/portcom.h>
#include <sys/lpc.h>
#include <os/task.h>
#include <os/msgpool.h>
#include <os/schedule.h>
#include <os/spinlock.h>
#include <os/debug.h>
#include <lib/errno.h>
#include <lib/string.h>
#include <lib/unistd.h>

/* Global port table; the array index is the port number and a slot is in use while flags != 0. */
static port_com_t port_com_table[PORT_COM_MAX];
/* Guards allocation, lookup and release of port_com_table slots. */
DEFINE_SPIN_LOCK(port_com_lock);
/* Incrementing counter used to tag request messages (see PortComGenerateMsgId). */
static uint32_t port_com_msgid = 0;

uint32_t PortComGenerateMsgId()
{
    return port_com_msgid++;
}

port_com_t *PortComAlloc()
{
    port_com_t *port_com;

    SpinLockDisInterrupt(&port_com_lock);
    for (int i = PORT_COM_UNNAME_START; i < PORT_COM_MAX; i++)
    {
        port_com = port_com_table + i;
        if (!port_com->flags)
        {
            port_com->flags = PORT_COM_USING;
            AtomicSet(&port_com->refer, 0);
            SpinLockInit(&port_com->lock);
            port_com->msgpool = NULL;
            SpinUnlockEnInterrupt(&port_com_lock);
            return port_com;
        }
    }
    SpinUnlockEnInterrupt(&port_com_lock);
}

/*
 * Release a port slot back to the table by clearing its flags.
 *
 * The msgpool is not touched here; callers destroy it first.
 *
 * Returns 0 when port_com is an in-use slot of port_com_table, -1 otherwise.
 */
int PortComFree(port_com_t *port_com)
{
    int ret = -1;

    SpinLockDisInterrupt(&port_com_lock);
    for (int i = 0; i < PORT_COM_MAX; i++)
    {
        port_com_t *slot = port_com_table + i;
        if (slot == port_com && slot->flags)
        {
            slot->flags = 0;
            ret = 0; /* success used to be reported as -1 as well */
            break;
        }
    }
    SpinUnlockEnInterrupt(&port_com_lock);
    return ret;
}

/*
 * Look up an in-use port by number.
 *
 * The port number is the table index, so no scan is needed.
 * Returns the slot, or NULL when the number is out of range or the slot is free.
 * port_com_lock is always released before returning; the old not-found path
 * leaked the lock with interrupts disabled and returned garbage.
 */
port_com_t *PortComFind(uint32_t port)
{
    port_com_t *port_com = NULL;

    if (port >= PORT_COM_MAX)
        return NULL;

    SpinLockDisInterrupt(&port_com_lock);
    if (port_com_table[port].flags)
        port_com = port_com_table + port;
    SpinUnlockEnInterrupt(&port_com_lock);
    return port_com;
}

/*
 * Convert a slot pointer back to its port number (table index).
 *
 * Pointer subtraction between port_com_t pointers already counts elements,
 * so it must NOT be divided by sizeof(port_com_t). The old code did divide,
 * which made almost every port report index 0.
 * Returns the index, or (uint32_t)-1 if port_com lies outside the table.
 */
uint32_t PortComPort2Idx(port_com_t *port_com)
{
    int idx = (int)(port_com - port_com_table);

    if (idx < 0 || idx >= PORT_COM_MAX)
        return (uint32_t)-1;
    return (uint32_t)idx;
}

port_com_t *PortComIdx2Port(int idx)
{
    return (port_com_t *)(idx < PORT_COM_MAX ? port_com_table + idx : NULL);
}

/*
 * Claim the slot for an explicit port number and reset it.
 *
 * The slot is reinitialized unconditionally (flags = PORT_COM_USING,
 * refer = 0, no msgpool); callers are expected to have checked with
 * PortComFind that it is free.
 * Returns the slot, or NULL when port is out of range. The lock is now
 * released on that path too (it used to leak, and the function fell off
 * the end without a return value).
 */
port_com_t *PortComRealloc(uint32_t port)
{
    port_com_t *port_com = NULL;

    SpinLockDisInterrupt(&port_com_lock);
    /* Range-check the unsigned value directly so large ports can't wrap negative. */
    if (port < PORT_COM_MAX)
    {
        port_com = port_com_table + port;
        port_com->flags = PORT_COM_USING;
        SpinLockInit(&port_com->lock);
        AtomicSet(&port_com->refer, 0);
        port_com->msgpool = NULL;
    }
    SpinUnlockEnInterrupt(&port_com_lock);
    return port_com;
}

/*
 * Verify that `task` may use a port and return that port in *port_out.
 *
 * The task must already be bound to some port, otherwise -EPERM.
 * port >= 0: the port must exist and have a non-negative refcount. A
 *            non-group port must also be the task's own bound port; group
 *            ports are accepted for any bound task.
 * port <  0: the task's own bound port is used, provided its slot is in use.
 *
 * Returns 0 on success, -EPERM otherwise (*port_out is left untouched).
 */
int PortComVertify(int port, port_com_t **port_out, task_t *task)
{
    port_com_t *target;

    if (task->port_comm == NULL)
        return -EPERM;

    if (port < 0)
    {
        /* No explicit port: fall back to the task's own binding. */
        target = task->port_comm;
        if (!target->flags)
            return -EPERM;
        *port_out = target;
        return 0;
    }

    target = PortComFind(port);
    if (target == NULL)
        return -EPERM;
    if (AtomicGet(&target->refer) < 0)
    {
        KPrint(PRINT_ERR "port unbind: port %d not binded\n", port);
        return -EPERM;
    }
    /* Only group ports may be used by tasks other than the bound one. */
    if (!(target->flags & PORT_COM_GROUP) && target != task->port_comm)
        return -EPERM;

    *port_out = target;
    return 0;
}

/*
 * Msgpool read callback: copy the message at the pool head into buff.
 *
 * The head slot begins with a port_msg_header_t whose size field gives the
 * message length. The copy is clamped to the pool's slot size so an
 * oversized header value cannot overrun the slot.
 */
static void MsgPoolGetCallBack(msgpool_t *pool, void *buff)
{
    port_msg_header_t *hdr = (port_msg_header_t *)pool->head;
    size_t len = MIN(hdr->size, pool->msgsize);

    memcpy(buff, hdr, len);
}

/*
 * Resolve or claim the port that `cur` should be bound to.
 *
 * - Task already bound: with PORT_BIND_ONCE the existing port is returned
 *   unchanged; otherwise this is an error. Previously it fell through, which
 *   either failed or leaked a newly allocated port.
 * - port >= 0 and in use: only allowed with PORT_BIND_GROUP and a live
 *   refcount (>= 1); the refcount is incremented.
 * - port >= 0 and free: the slot is claimed via PortComRealloc with refcount 1.
 *   The old code inverted the NULL check here and always returned NULL.
 * - port < 0: any free slot is claimed via PortComAlloc with refcount 1.
 *
 * PORT_BIND_GROUP marks the port with PORT_COM_GROUP.
 * Returns the port, or NULL on failure.
 */
static port_com_t *_SysPortComBind(int port, task_t *cur, int flags)
{
    port_com_t *port_com;

    if (cur->port_comm)
    {
        if (flags & PORT_BIND_ONCE)
            return cur->port_comm;
        KPrint(PRINT_ERR "[port bind]port %d had bounded on task %s\n", port, cur->name);
        return NULL;
    }

    if (port >= 0)
    {
        port_com = PortComFind(port);
        if (port_com)
        {
            /* Port already exists: only a group join may share it. */
            if (!(flags & PORT_BIND_GROUP))
            {
                KPrint("[port bind]port %d had used!\n", port);
                return NULL;
            }
            if (AtomicGet(&port_com->refer) < 1)
            {
                KPrint(PRINT_ERR "[port bind]port %d get refer failed!\n", port);
                return NULL;
            }
            AtomicInc(&port_com->refer);
        }
        else
        {
            /* Port not present: claim that exact slot. */
            port_com = PortComRealloc(port);
            if (!port_com)
                return NULL;
            AtomicInc(&port_com->refer);
        }
    }
    else
    {
        /* No port requested: allocate any free one. */
        port_com = PortComAlloc();
        if (!port_com)
            return NULL;
        AtomicInc(&port_com->refer);
    }

    if (flags & PORT_BIND_GROUP)
        port_com->flags |= PORT_COM_GROUP;
    return port_com;
}

/*
 * Syscall: bind the current task to `port` (or to any free port if port < 0).
 *
 * PORT_BIND_ONCE: if the task is already bound, nothing is done and 0 is
 *   returned. A message is logged if a different explicit port was requested.
 *   An unbound task goes through the normal bind path; it used to return 0
 *   without ever binding.
 * PORT_BIND_GROUP: joining an existing group port (refcount > 1 after the
 *   join) reuses its msgpool; the first binder creates it.
 *
 * Returns 0 on success, -EINVAL for a bad port, -ENOMEM if the msgpool
 * cannot be created, -1 if the bind itself fails.
 * Every path now releases port_com->lock; the BIND_ONCE and group-join
 * paths used to return with it held and interrupts disabled.
 */
int SysPortComBind(int port, int flags)
{
    task_t *cur = cur_task;
    port_com_t *port_com;

    if (IS_BAD_PORT_COM(port))
        return -EINVAL;

    /* Already bound and asked to bind only once: done, but flag a mismatch. */
    if ((flags & PORT_BIND_ONCE) && cur->port_comm)
    {
        if (port >= 0 && cur->port_comm != PortComFind(port))
            KPrint("[port bind]pid=%d port=%d bind once error because not same  port!\n", cur->pid, port);
        return 0;
    }

    port_com = _SysPortComBind(port, cur, flags);
    if (!port_com)
    {
        KPrint("[port bind]port %d bind failed!\n", port);
        return -1;
    }

    SpinLockDisInterrupt(&port_com->lock);
    /* Joining an existing group port: its msgpool was created by the first binder. */
    if ((flags & PORT_BIND_GROUP) && AtomicGet(&port_com->refer) > 1)
    {
        TASK_BIND_PORT_COM(cur, port_com);
        SpinUnlockEnInterrupt(&port_com->lock);
        return 0;
    }

    /* First binder of a new port: create its msgpool. */
    port_com->msgpool = MsgpoolCreate(sizeof(port_msg_t), PORT_MSG_NUM);
    if (!port_com->msgpool)
    {
        SpinUnlockEnInterrupt(&port_com->lock);
        PortComFree(port_com);
        return -ENOMEM;
    }
    /* The first bind assigns the port number recorded in outgoing requests. */
    port_com->port = PortComPort2Idx(port_com);
    TASK_BIND_PORT_COM(cur, port_com);
    SpinUnlockEnInterrupt(&port_com->lock);
    return 0;
}

int SysPortComUnBind(int port)
{
    task_t *cur = cur_task;
    port_com_t *port_com;

    if (IS_BAD_PORT_COM(port))
    {
        KPrint(PRINT_ERR "[port unbind]port %d is bad port\n", port);
        return -EINVAL;
    }
    if (PortComVertify(port, &port_com, cur) < 0)
    {
        return -EPERM;
    }
    SpinLockDisInterrupt(&port_com->lock);
    if (AtomicGet(&port_com->refer) > 0)
    {
        AtomicDec(&port_com->refer);
        TASK_UNBIND_PORT_COM(cur);
        SpinUnlockEnInterrupt(&port_com->lock);
        return 0;
    }
    //close and free msgpool
    MsgpoolDestroy(port_com->msgpool);
    port_com->msgpool = NULL;
    port_com->port = -1;
    TASK_UNBIND_PORT_COM(cur);
    SpinUnlockEnInterrupt(&port_com->lock);
    PortComFree(port_com);
    return 0;
}

/*
 * Syscall: send `msg` to `port` and wait for the reply in the caller's own port.
 *
 * The caller must be bound to a port and may not target itself. The header
 * is stamped with a fresh id and with the caller's port number, so the
 * receiver can reply.
 * The wait polls the caller's msgpool and yields after every
 * PORT_COM_RETRY_GET_MAX failed attempts. It aborts with -EINTR when
 * ExceptionCauseExit reports a pending exit.
 * The reply overwrites `msg`; -EPERM is returned if its id does not match.
 *
 * Both msgpools are now checked for NULL, as Receive/Reply already do, so a
 * port that has no pool yet is rejected instead of dereferenced.
 */
int SysPortComRequest(uint32_t port, port_msg_t *msg)
{
    task_t *cur = cur_task;
    port_com_t *port_com, *myport_com;
    uint32_t msgid;
    uint32_t try_count = 0;

    if (IS_BAD_PORT_COM(port))
    {
        KPrint(PRINT_ERR "%s: port %d invalid\n", __func__, port);
        return -EINVAL;
    }
    if (PortComVertify(-1, &myport_com, cur) < 0) /* caller's own bound port */
    {
        KPrint(PRINT_ERR "%s: port %d vertify failed!\n", __func__, port);
        return -EPERM;
    }
    port_com = PortComFind(port);
    if (!port_com)
        return -EPERM;
    if (port_com == myport_com) /* cannot send to itself */
    {
        KPrint(PRINT_ERR "%s: port %d can not request itself\n", __func__, port);
        return -EPERM;
    }
    if (!port_com->msgpool || !myport_com->msgpool)
        return -EPERM;

    /* Tag the request so the reply can be matched to it. */
    msgid = PortComGenerateMsgId();
    msg->header.id = msgid;
    msg->header.port = myport_com->port;

    if (MsgpoolPut(port_com->msgpool, msg, msg->header.size) < 0)
    {
        KPrint(PRINT_ERR "%s: msg put to %d failed!\n", __func__, port);
        return -EPERM;
    }

    /* Poll the caller's own pool for the reply, yielding periodically. */
    while (MsgpoolTryGet(myport_com->msgpool, msg, MsgPoolGetCallBack) < 0)
    {
        if (ExceptionCauseExit(&cur->exception_manager))
            return -EINTR;
        try_count++;
        if (try_count > PORT_COM_RETRY_GET_MAX)
        {
            TaskYield();
            try_count = 0;
        }
    }

    if (msg->header.id != msgid)
    {
        KPrint(PRINT_WARNNING "%s: port %d msg id %d:%d invalid\n", __func__, port, msg->header.id, msgid);
        return -EPERM;
    }
    return 0;
}

/*
 * Syscall: wait for the next message on `port` (the caller's own port if
 * port < 0) and copy it into `msg`.
 *
 * The wait polls the port's msgpool and yields after every
 * PORT_COM_RETRY_GET_MAX failed attempts. It returns -EINTR when
 * ExceptionCauseExit reports a pending exit.
 * Returns 0 once a message is copied, -EINVAL for a bad port, -EPERM when
 * verification fails or the port has no msgpool.
 */
int SysPortComReceive(int port, port_msg_t *msg)
{
    task_t *cur = cur_task;
    port_com_t *port_com;
    uint32_t spins = 0;

    if (IS_BAD_PORT_COM(port))
    {
        KPrint(PRINT_ERR "%s: port %d invalid\n", __func__, port);
        return -EINVAL;
    }
    if (PortComVertify(port, &port_com, cur) < 0)
    {
        KPrint(PRINT_ERR "%s: port %d vertify failed!\n", __func__, port);
        return -EPERM;
    }
    if (port_com->msgpool == NULL)
        return -EPERM;

    for (;;)
    {
        if (MsgpoolTryGet(port_com->msgpool, msg, MsgPoolGetCallBack) >= 0)
            return 0;
        if (ExceptionCauseExit(&cur->exception_manager))
        {
            KPrint("%s: port %d interrupt by exception\n", __func__, port);
            return -EINTR;
        }
        if (++spins > PORT_COM_RETRY_GET_MAX)
        {
            TaskYield();
            spins = 0;
        }
    }
}

/*
 * Syscall: send `msg` as a reply to the port named in msg->header.port.
 *
 * `port` is verified against the caller first, as in Receive.
 * msg->header.port is read from the message and is not otherwise validated,
 * so it is range-checked here. Previously an out-of-range value made
 * PortComIdx2Port return NULL, which was then dereferenced, or a negative
 * value indexed outside the table. Free slots are rejected as well.
 * Returns MsgpoolPut's result, or -EINVAL / -EPERM on validation failure.
 */
int SysPortComReply(int port, port_msg_t *msg)
{
    task_t *cur = cur_task;
    port_com_t *port_com, *reply_port;

    if (IS_BAD_PORT_COM(port))
    {
        KPrint(PRINT_ERR "%s: port %d invalid\n", __func__, port);
        return -EINVAL;
    }
    /* Verify the caller may use `port`. */
    if (PortComVertify(port, &port_com, cur) < 0)
    {
        KPrint(PRINT_ERR "%s: port %d vertify failed!\n", __func__, port);
        return -EPERM;
    }
    /* Unsigned compare also rejects negative values if the field is signed. */
    if ((uint32_t)msg->header.port >= (uint32_t)PORT_COM_MAX)
        return -EPERM;
    reply_port = PortComIdx2Port((int)msg->header.port);
    if (!reply_port || !reply_port->flags || !reply_port->msgpool)
        return -EPERM;
    return MsgpoolPut(reply_port->msgpool, msg, msg->header.size);
}

/* Zero every byte of the message, header included. */
void PortMsgReset(port_msg_t *msg)
{
    size_t nbytes = sizeof(*msg);

    memset(msg, 0, nbytes);
}

/* Byte-copy src's header into dest's header; the rest of dest is left untouched. */
void PortMsgCopyHeader(port_msg_t *src, port_msg_t *dest)
{
    port_msg_header_t *to = &dest->header;
    const port_msg_header_t *from = &src->header;

    memcpy(to, from, sizeof(*to));
}
